{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f2254819-deaf-48ba-848c-471f51ee1221",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4fe03089-ec4d-4bda-9b02-46cb320e516a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Source checkpoint to convert. The literal is split at directory boundaries\n",
    "# purely for readability; adjacent string literals are concatenated by Python.\n",
    "origin_checkpoint_path = (\n",
    "    '/mnt/cache/zhujinguo/codes/UniPerceiver/work_dirs/deepspeed_moe/'\n",
    "    'BERT_L12_H768_experiments/'\n",
    "    '16task_90k_bertbase_lr1e-3_wd0.2_gc0.1_prenorm_warm10k_layerscale1e-3_uniformdp0.1_maeinit_fixedpos_torchfp16_unifieddataset_changeweight_stage2_224size/'\n",
    "    'bertbase_womoe_pretrain2/89999/'\n",
    "    'mp_rank_00_model_states.pt'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "edc282cd-8345-4321-b0a0-3e21d64bfa35",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['module', 'buffer_names', 'optimizer', 'lr_scheduler', 'sparse_tensor_module_names', 'skipped_steps', 'global_steps', 'global_samples', 'dp_world_size', 'mp_world_size', 'ds_config', 'ds_version', 'iteration'])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Deserialize the checkpoint with every tensor mapped onto the CPU.\n",
    "origin_checkpoint = torch.load(origin_checkpoint_path, map_location='cpu')\n",
    "\n",
    "# Display the top-level sections of the checkpoint. The weights converted below\n",
    "# live under 'module'; list(origin_checkpoint['module'].keys()) shows their names.\n",
    "origin_checkpoint.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "79d9f479-3144-4791-82ba-71fec264aa29",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "201"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of entries stored in the source state dict.\n",
    "len(origin_checkpoint['module'])"
   ]
  },
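  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "3452947d-4593-4431-a772-3a8ad4882c03",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['model', 'trainer', 'amp_scaler', 'scheduler', 'iteration'])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Disabled reference check: load a checkpoint in the target format and inspect\n",
    "# its top-level keys / parameter names for comparison with the converted result.\n",
    "# NOTE: the output stored with this cell comes from an earlier run made while\n",
    "# the code below was still enabled; it is not reproduced by Restart & Run All.\n",
    "# new_checkpoint_path = 'new_exp/model_Epoch_00160_Iter_0000159.pth'\n",
    "# new_checkpoint = torch.load(new_checkpoint_path, 'cpu')\n",
    "# new_checkpoint.keys()\n",
    "# list(new_checkpoint['model'].keys())"
   ]
  },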
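  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "ffdcf5c5-ffd4-4379-89d7-37ce05c4c0f2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "41"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Disabled reference check: number of entries in the reference checkpoint's\n",
    "# 'model' state dict. NOTE: the stored output comes from an earlier run made\n",
    "# while this line was still enabled.\n",
    "# len(list(new_checkpoint['model'].keys()))"
   ]
  },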
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "fec7a303-a30c-4e92-9452-b534a52d67e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Substring renames from source parameter names to target parameter names.\n",
    "# The conversion loop applies them one after another in this order, so the\n",
    "# order of the pairs is part of the mapping.\n",
    "mapping_dict = dict(\n",
    "    [\n",
    "        # top-level module names\n",
    "        ('encoder.', 'fused_encoder.'),\n",
    "        # attention block\n",
    "        ('attention.self.qkv_proj.weight', 'self_attn.in_proj_weight'),\n",
    "        ('attention.self.qkv_proj.bias', 'self_attn.in_proj_bias'),\n",
    "        ('attention.output.dense', 'self_attn.out_proj'),\n",
    "        ('attention_output.residual_scale', 'gamma_1'),\n",
    "        # feed-forward block\n",
    "        ('ffn.dense.', 'linear1.'),\n",
    "        ('ffn.dense2.', 'linear2.'),\n",
    "        ('ffn_output.residual_scale', 'gamma_2'),\n",
    "        # layer norms\n",
    "        ('LayerNormModules.0.', 'norm1.'),\n",
    "        ('LayerNormModules.1.', 'norm2.'),\n",
    "        # prediction head\n",
    "        ('predictor.', 'loss_prepare.'),\n",
    "    ]\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "897ff2f0-1232-4d25-9c13-7ea9568da362",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the converted state dict: rename every parameter of the source\n",
    "# checkpoint according to mapping_dict and cast it to float32.\n",
    "new_checkpoint = {}\n",
    "\n",
    "module_checkpoint = origin_checkpoint['module']\n",
    "\n",
    "for origin_key, tensor in module_checkpoint.items():\n",
    "    # Parameters under visual_embed are not carried over.\n",
    "    if origin_key.startswith('visual_embed'):\n",
    "        continue\n",
    "\n",
    "    if origin_key.endswith('residual_scale'):\n",
    "        # Squeeze dims 1 and 0 out-of-place so the tensors inside\n",
    "        # origin_checkpoint stay untouched and this cell can be re-run;\n",
    "        # an in-place squeeze_ would fail on the second run because\n",
    "        # dim 1 no longer exists on the already-squeezed tensor.\n",
    "        tensor = tensor.squeeze(1).squeeze(0)\n",
    "\n",
    "    # Apply every substring rename in the order mapping_dict lists them;\n",
    "    # str.replace leaves the key unchanged when the pattern is absent.\n",
    "    target_key = origin_key\n",
    "    for origin_str, target_str in mapping_dict.items():\n",
    "        target_key = target_key.replace(origin_str, target_str)\n",
    "\n",
    "    new_checkpoint[target_key] = tensor.float()\n",
    "\n",
    "# merge type embedding in video_embed: add row 0 of the type-embedding table\n",
    "# to the video embedding bias (the type-embedding entry itself is kept as is).\n",
    "video_type_embedding = new_checkpoint['video_embed.embeddings_type.weight'][0]\n",
    "new_checkpoint['video_embed.embeddings.bias'] = (\n",
    "    new_checkpoint['video_embed.embeddings.bias'] + video_type_embedding\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "3c26719f-7451-4c0a-85c3-640c820dfe98",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the converted weights to disk, wrapped under the 'model' key.\n",
    "output_checkpoint_path = '/mnt/lustre/zhujinguo/codes/Uni-Perceiver/work_dirs/pretrained_models/uni-perceiver-base-L12-H768-224size-pretrained.pth'\n",
    "converted_payload = {'model': new_checkpoint}\n",
    "torch.save(converted_payload, output_checkpoint_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
